from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from models.judge import Judge
from attack.score import Score, Text
from models.substitutors import Substitutor, ModernBertSubstitutor, GPTSubstitutor
from models.evaluator import HarmBenchJudge


class GreedySearchOptimizer:
    """Greedy word-substitution search over a prompt, driven by a ``Score``.

    Every iteration tries each candidate the substitutor suggests for each
    word index of the current prompt, keeps the single highest-scoring
    rewrite (optionally only when a ``Judge`` accepts it), and optionally
    asks a ``HarmBenchJudge`` to evaluate the model's output for that rewrite.
    """

    def __init__(
        self,
        model,
        layer: int,
        target_embedding: np.ndarray,
        num_substitutions: int = 20,
        max_iterations: int = 30,
        substitutor: Type[Substitutor] = GPTSubstitutor,
        judge: bool = True,
        evaluator: bool = True,
        device: Union[str, Dict[str, str]] = "cuda",
    ):
        """Store the search configuration and build the helper components.

        Args:
            model: Model handed to ``Score``; ``optimize`` also calls its
                ``generate_tokens`` method when the evaluator is enabled.
            layer: Layer index forwarded to ``Score``.
            target_embedding: Target embedding forwarded to ``Score``.
            num_substitutions: Passed to the substitutor as ``k``.
            max_iterations: Upper bound on the number of greedy iterations.
            substitutor: Substitutor *class* (not an instance); it is
                instantiated here with ``k`` and ``device``.
            judge: When true, a ``Judge`` must accept a candidate before it
                can replace the current best.
            evaluator: When true, a ``HarmBenchJudge`` evaluates the model
                output after each iteration.
            device: Either a single device string used for every component,
                or a mapping with a ``'main_device'`` key and, when
                ``evaluator`` is true, an ``'evaluator_device'`` key.
        """
        # The original code indexed ``device`` like a mapping even though the
        # default is a plain string; normalise so both forms work.
        if isinstance(device, str):
            device = {'main_device': device, 'evaluator_device': device}

        self.model = model
        self.layer = layer
        self.target_embedding = target_embedding
        self.num_substitutions = num_substitutions
        self.max_iterations = max_iterations
        self.device = device['main_device']

        self.substitutor = substitutor(
            k=self.num_substitutions,
            device=self.device,
        )
        self.judge = Judge() if judge else None
        # 'evaluator_device' is only looked up when the evaluator is enabled.
        self.evaluator = (
            HarmBenchJudge(device=device['evaluator_device']) if evaluator else None
        )

    def _intent_preserved(self, initial_prompt, candidate_text) -> bool:
        """Return True when no judge is configured or the judge answers 'true'."""
        if self.judge is None:
            return True
        verdict = self.judge.relabel(initial_prompt, str(candidate_text))
        return verdict.strip().lower() == 'true'

    @staticmethod
    def _completion_after_prompt(generated: str, prompt: str) -> str:
        """Return the text following the first occurrence of ``prompt``.

        Mirrors ``generated.split(prompt)[1]`` when the prompt is present, but
        returns ``generated`` unchanged instead of raising when the prompt is
        empty or is not echoed in the generation.
        """
        if not prompt:
            # str.split('') raises ValueError; nothing to strip anyway.
            return generated
        pieces = generated.split(prompt)
        return pieces[1] if len(pieces) > 1 else generated

    def optimize(
        self, initial_prompt: str
    ) -> Tuple[List[Tuple[str, float]], Optional[str], Any]:
        """
        Run greedy search optimization

        Args:
            initial_prompt: Starting prompt to optimize

        Returns:
            A 3-tuple ``(optimized_prompts, output, eval_score)``:
            ``optimized_prompts`` is the list of (prompt, score) pairs, the
            initial prompt first and then one entry per completed iteration;
            ``output`` and ``eval_score`` come from the most recent
            evaluation, or are both ``None`` if no evaluation ran (evaluator
            disabled or ``max_iterations`` is 0).
        """
        # Only word indices that exist in the initial prompt are searched.
        word_count = len(Text._text_to_words(str(initial_prompt)))
        current_text = Text(initial_prompt)

        # Initialize objective
        self.objective = Score(
            model=self.model,
            layer=self.layer,
            target_embedding=self.target_embedding,
            original_text=initial_prompt,
            device=self.device,
        )

        optimized_prompts = [(str(current_text), self.objective.score(current_text))]
        output, eval_score = None, None

        for _ in tqdm(range(self.max_iterations), desc="Optimizing prompt"):
            # Start each iteration from the score reached by the previous one.
            best_score = optimized_prompts[-1][1]
            best_text = current_text

            for word_idx in range(word_count):
                candidates = self.substitutor.suggestions(current_text, word_idx)
                for candidate in candidates:
                    new_text = current_text.replace(word_idx, candidate)
                    score = self.objective.score(new_text)

                    # Higher is better. The judge is consulted only for
                    # candidates that already beat the best score, thanks to
                    # short-circuit evaluation.
                    if score > best_score and self._intent_preserved(
                        initial_prompt, new_text
                    ):
                        best_score = score
                        best_text = new_text
                        print("New best score", best_score)
                        print(best_text)

            current_text = best_text
            optimized_prompts.append((str(current_text), best_score))

            if self.evaluator is not None:
                print("\n\nEvaluating")
                prompt = str(best_text)
                # NOTE(review): assumes generate_tokens returns a str that
                # normally echoes the prompt before the completion - confirm.
                output = self._completion_after_prompt(
                    self.model.generate_tokens(prompt), prompt
                )
                print("Output:", output)
                eval_score = self.evaluator.is_jailbreaks(initial_prompt, output)
                print("Eval score:", eval_score)
                if eval_score == 1:
                    # Stop early once the evaluator reports a score of 1.
                    print("Jailbroken")
                    return optimized_prompts, output, eval_score

        return optimized_prompts, output, eval_score